#!/usr/bin/env python
# -*- coding: utf-8; py-indent-offset:4 -*-
###############################################################################
#
# Copyright (C) 2015-2020 Daniel Rodriguez
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
###############################################################################
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import backtrader as bt
from backtrader.utils import AutoOrderedDict
from backtrader.utils.datehelper import *
import pandas as pd
import os

"""
净值
"""


class NetAssetValue(bt.TimeFrameAnalyzerBase):
    """Analyzer that records broker cash and account value per period.

    Every time ``on_dt_over`` runs, the current cash and value of the
    strategy's broker are stored in ``self.rets`` under ``self.dtkey``
    (both attributes are inherited from the analyzer base classes).

    Params:
      - ``timeframe`` / ``compression``: period granularity handed to the
        base class (defaults: ``bt.TimeFrame.Days`` / ``1``).
      - ``fund``: if ``None`` the broker's ``fundmode`` decides whether the
        plain account value or the fund value is recorded; otherwise the
        truthiness of this parameter decides.
      - ``csv``: when truthy (and ``out`` is set) ``stop`` writes the
        collected data to ``<out>/nav.csv``.
      - ``out``: output directory. Either a plain path, or the source text
        of an f-string literal (e.g. ``'f"results/{asset}"'``) which is
        evaluated against the strategy parameters plus ``asset`` (the name
        of the analyzer's data feed).
      - ``rounding``: number of decimals applied to the resulting
        DataFrame; ``None`` disables rounding.
    """

    params = (
        ('timeframe', bt.TimeFrame.Days),
        ('compression', 1),
        ('fund', None),
        ('csv', False),
        ('out', None),
        ('rounding', 6)
    )

    def start(self):
        """Determine whether fund value or plain account value is tracked."""
        super(NetAssetValue, self).start()
        if self.p.fund is None:
            # No explicit choice: follow the broker configuration.
            self._fundmode = self.strategy.broker.fundmode
        else:
            self._fundmode = self.p.fund

    def on_dt_over(self):
        """Store ``[cash, value]`` for the current period key.

        Called when the current period is over (presumably driven by the
        base class -- confirm against ``TimeFrameAnalyzerBase``) and once
        more from ``stop`` so the final state is always captured.
        """
        broker = self.strategy.broker
        if self._fundmode:
            value = broker.fundvalue
        else:
            value = broker.getvalue()
        self.rets[self.dtkey] = [broker.getcash(), value]

    def stop(self):
        """Record the final state and optionally dump the data as CSV."""
        self.on_dt_over()

        if not (self.p.csv and self.p.out):
            return

        out = self._resolve_out_dir()
        if not os.path.exists(out):
            os.makedirs(out)
        fpath = os.path.join(out, "nav.csv")
        self.to_dataframe().to_csv(fpath, index=False)

    def _resolve_out_dir(self):
        """Return the output directory described by the ``out`` parameter.

        Only values that carry an f-string literal prefix (``f"``, ``f'``,
        ``fr"`` ...) are evaluated; everything else is used verbatim, so an
        ordinary path that merely starts with the letter "f" (such as
        ``final/run1``) is no longer mistaken for an expression.
        """
        out = self.p.out
        fstring_prefixes = ('f"', "f'", 'fr"', "fr'", 'fR"', "fR'")
        if not out.startswith(fstring_prefixes):
            return out

        # Names available inside the template: all strategy parameters plus
        # ``asset``. Work on a copy because eval() injects ``__builtins__``
        # into the globals mapping it receives.
        namespace = self.strategy.p.__dict__.copy()
        namespace['asset'] = self.data._name

        # NOTE(security): eval() executes arbitrary code contained in
        # ``out``. This is only acceptable while ``out`` comes from trusted
        # configuration; never feed it user-controlled input.
        return eval(out, namespace)

    def to_dataframe(self):
        """Return the recorded data as a ``pandas.DataFrame``.

        Columns are ``date``, ``symbol``, ``cash`` and ``value``; one row
        per entry of ``self.rets`` in insertion order. Numeric columns are
        rounded to ``rounding`` decimals unless that parameter is ``None``.
        """
        cols = ["date", "symbol", "cash", "value"]
        symbol = self.data._name
        data = [
            [date2str(date), symbol] + values
            for date, values in self.rets.items()
        ]

        df = pd.DataFrame(columns=cols, data=data)
        if self.p.rounding is not None:
            # DataFrame.round() rejects ``None``; treat it as "no rounding".
            df = df.round(decimals=self.p.rounding)
        return df
